Conversation
Thanks for the PR! Some preliminary design questions and comments:
```python
self.register_to_config(
    seq_len=seq_len,
    num_inference_steps=num_inference_steps,
    inject_start_token=inject_start_token,
)
```
Generally we don't register default __call__ arguments to the config, but rather set them as default arguments to the __call__ method:
See diffusers/src/diffusers/pipelines/ltx2/pipeline_ltx2.py, lines 744 to 752 (d4f97d1).
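Something like this sketch of the pattern (the class and argument names are illustrative stand-ins, not the PR's actual code):

```python
# Hypothetical sketch: defaults live on __call__ rather than in the config,
# so no config registration or None-fallback logic is needed.
class SketchPipeline:
    def __init__(self):
        # Nothing call-specific is registered to the config here.
        pass

    def __call__(
        self,
        seq_len: int = 128,
        num_inference_steps: int = 50,
        inject_start_token: bool = False,
    ):
        # No `if seq_len is None: seq_len = self.config.seq_len` needed;
        # Python resolves the defaults itself.
        return seq_len, num_inference_steps, inject_start_token
```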
```python
*,
batch_size: int = 1,
```
diffusers pipelines usually don't set __call__ arguments to be keyword-only. (That's not to say that there are no arguments for it, but because other pipelines allow positional arguments I think the expectation is that discrete diffusion pipelines will allow them as well.)
```python
if seq_len is None:
    seq_len = int(self.config.seq_len)
if num_inference_steps is None:
    num_inference_steps = int(self.config.num_inference_steps)
if inject_start_token is None:
    inject_start_token = bool(self.config.inject_start_token)
```
Following up on #12911 (comment), this logic could be removed if we don't register default arguments to the config.
```python
if infill_mask is not None:
    if infill_mask.shape != (batch_size, seq_len):
        raise ValueError(
            f"`infill_mask` must have shape {(batch_size, seq_len)}, got {tuple(infill_mask.shape)}."
        )
```
I think input checking and exceptions should be moved to a check_inputs method, which is the usual practice for diffusers pipelines:
See diffusers/src/diffusers/pipelines/flux2/pipeline_flux2.py, lines 686 to 693 (d4f97d1).
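Roughly like this (a hedged sketch of the usual diffusers-style pattern; the class and argument names are illustrative, not the PR's code):

```python
# Sketch: input validation lives in a dedicated check_inputs method that
# __call__ invokes before doing any heavy work.
class SketchPipeline:
    def check_inputs(self, batch_size, seq_len, infill_mask=None):
        if infill_mask is not None and tuple(infill_mask.shape) != (batch_size, seq_len):
            raise ValueError(
                f"`infill_mask` must have shape {(batch_size, seq_len)}, "
                f"got {tuple(infill_mask.shape)}."
            )

    def __call__(self, batch_size=1, seq_len=8, infill_mask=None):
        # 1. Check inputs first, as other diffusers pipelines do.
        self.check_inputs(batch_size, seq_len, infill_mask)
```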
```python
        return int(token_id)
    return None


def _init_latents(
```
We usually name methods which sample latents from the prior distribution prepare_latents:
```python
if hasattr(self.scheduler, "forward_process") and getattr(self.scheduler, "forward_process") == "uniform":
    # Uniform prior over token IDs. Mirror scheduler's exclude-mask behavior.
    if getattr(self.scheduler, "exclude_mask_from_uniform", False) and hasattr(
        self.scheduler, "_sample_uniform_tokens"
    ):
        return self.scheduler._sample_uniform_tokens(
            torch.Size((batch_size, seq_len)),
            device=device,
            dtype=torch.long,
            generator=generator,
        )
    vocab_size = int(getattr(self.scheduler, "vocab_size", 0))
    if vocab_size <= 0:
        raise ValueError("Scheduler must define `vocab_size` for uniform prior sampling.")
    return torch.randint(
        0, vocab_size, (batch_size, seq_len), device=device, dtype=torch.long, generator=generator
    )
```
Suggestion: maybe it would be cleaner to define a scheduler method called (say) sample_prior which samples from the prior distribution based on the configured forward_process? So if self.forward_process == "uniform", we would call _sample_uniform_tokens under the hood in sample_prior to sample from a uniform prior distribution.
I think this would allow for more graceful support of other possible forward processes, and make the pipeline code cleaner (as most of the logic would be handled inside the scheduler).
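Something along these lines (a sketch only; the scheduler class and its attribute names are hypothetical stand-ins for the PR's scheduler):

```python
import torch

# Sketch of the suggested scheduler-side sample_prior method. The pipeline
# would then just call scheduler.sample_prior((batch_size, seq_len), ...).
class SketchScheduler:
    def __init__(self, forward_process="uniform", vocab_size=32, mask_token_id=0):
        self.forward_process = forward_process
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id

    def sample_prior(self, shape, device=None, generator=None):
        if self.forward_process == "uniform":
            # Uniform prior over token IDs.
            return torch.randint(
                0, self.vocab_size, shape, device=device, dtype=torch.long, generator=generator
            )
        if self.forward_process == "absorbing":
            # Fully-masked prior for absorbing-state diffusion.
            return torch.full(shape, self.mask_token_id, device=device, dtype=torch.long)
        raise ValueError(f"Unknown forward_process: {self.forward_process!r}")
```

This keeps the forward-process dispatch inside the scheduler, so adding a new prior only touches one place.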
```python
# 3. Prepare latents
input_ids = self.prepare_latents(batch_size, seq_len, generator=generator, device=device)
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
```
It looks like attention_mask is currently not being used in this pipeline, is this expected?
```python
texts: list[str] | None = None


class TokenDiffusionPipeline(DiffusionPipeline, DiscreteDiffusionPipelineMixin):
```
Can you add a usage example for TokenDiffusionPipeline?
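For example, something along these lines (the checkpoint id is a placeholder for a real checkpoint, and the argument names follow the PR's `__call__` signature as quoted above):

```python
import torch
from diffusers import TokenDiffusionPipeline

# "org/token-diffusion-demo" is a placeholder; substitute a real checkpoint id.
pipe = TokenDiffusionPipeline.from_pretrained("org/token-diffusion-demo")
pipe.to("cuda")

output = pipe(
    batch_size=1,
    seq_len=128,
    num_inference_steps=50,
    generator=torch.Generator("cuda").manual_seed(0),
)
print(output.texts[0])
```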
```python
texts: list[str] | None = None


class HybridTokenDiffusionPipeline(TokenDiffusionPipeline):
```
Can we inherit from DiffusionPipeline instead of TokenDiffusionPipeline and copy over common methods as necessary?
```python
texts: list[str] | None = None


class HybridTokenDiffusionPipeline(TokenDiffusionPipeline):
```
My understanding is that TokenDiffusionPipeline and HybridTokenDiffusionPipeline implement essentially the same logic (but are intended to be used with different schedulers). Would it be possible to consolidate these two pipelines into a single TokenDiffusionPipeline which works with both TokenDiffusionScheduler and HybridTokenDiffusionScheduler?
```python
cur_x = x[:, start:end].clone()
cur_position_ids = position_ids[:, start:end]
cur_attn_mask = attn_mask[start:end, :end].unsqueeze(0)
```
Suggested change:

```suggestion
block_x = x[:, start:end].clone()
block_position_ids = position_ids[:, start:end]
block_attn_mask = attn_mask[start:end, :end].unsqueeze(0)
```
nit: rename block-level variables to use the prefix block (e.g. cur_x --> block_x) following the LLaDA 2 pipeline.
```python
num_blocks = (prompt_length + int(max_new_tokens) + int(block_length) - 1) // int(block_length)
total_length = int(num_blocks) * int(block_length)
```
Suggested change:

```suggestion
num_blocks = (prompt_length + max_new_tokens + block_length - 1) // block_length
total_length = num_blocks * block_length
```
Can we remove the int(...) casts here and elsewhere? It makes the code more readable and we are annotating max_new_tokens and block_length as ints so this should be safe.
```python
transfer_index = step_output.transfer_index
sampled_tokens = step_output.sampled_tokens
sampled_probs = step_output.sampled_probs
```
Suggested change (delete these lines):

```suggestion
```
I think we can remove this as these variables are no longer being used.
```python
# Get model predictions only when p_x0 cache is invalidated
if p_x0_cache is None:
    sigma_t = self.scheduler.compute_sigma(t, batch_size)
    model_input = x_accum[:, -block_length:]
```
Would it be possible to refactor the block logic here to be more parallel with other block discrete diffusion pipelines such as LLaDA2Pipeline?
See diffusers/src/diffusers/pipelines/llada2/pipeline_llada2.py, lines 384 to 387 (e365d74).
```python
)


@classmethod
def from_pretrained(
```
Why do we need to override from_pretrained for the DFlash pipeline?
```python
texts: list[str] | None = None


def _build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]:
```
I think inlining this method would be more clear as we only call it once (in _get_target_layer_ids).
```python
self.draft_model.eval()
self.target_model.eval()
```
Suggested change (delete these lines):

```suggestion
```
I think the draft_model and target_model should already be set to eval mode, so we don't need to explicitly call it here.
```python
    return sequences, texts
return DFlashPipelineOutput(sequences=sequences, texts=texts)


def _get_block_size(self) -> int:
```
Do we need to support general draft and target models in _get_block_size and other methods below? My impression is that the DFlash model uses unique modeling logic, so it seems unlikely that we could drop in a random draft or target model and have it work out of the box. So I think it's reasonable to only support existing DFlash checkpoints such as z-lab/Qwen3-8B-DFlash-b16.
```python
The returned tensor is expected to be in (0, 1] and monotone decreasing in `t`.
"""
if self.alpha_schedule == "log_linear":
```
Can you give a reference which uses these alpha schedules?
```python
noised = torch.where(block_mask.to(device=device), noised, original_samples)
return noised


def enforce_fixed_masks(
```
I think it might be better to inline enforce_fixed_masks in the pipeline code (e.g. in TokenDiffusionPipeline), as it is relatively simple and the choice of whether and how to enforce prefix/infill conditioning seems more like a pipeline design choice.
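Inlined, this could be as small as a single `torch.where` in the denoising loop (a sketch with illustrative names and values; `fixed_mask` would mark prefix/infill positions whose tokens must be preserved):

```python
import torch

# Sketch: keep conditioning tokens at fixed positions, take freshly sampled
# tokens everywhere else. All tensors here are toy examples.
fixed_mask = torch.tensor([[True, True, False, False]])
conditioning_ids = torch.tensor([[5, 6, 0, 0]])
sampled_ids = torch.tensor([[9, 9, 9, 9]])

input_ids = torch.where(fixed_mask, conditioning_ids, sampled_ids)
```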
```python
self.vocab_size = int(vocab_size)
self.mask_token_id = int(mask_token_id)
self.num_train_timesteps = int(num_train_timesteps)
self.t_eps = float(t_eps)
```
Suggested change (delete these lines):

```suggestion
```
I think this is unnecessary as register_to_config should make these available as self.config.vocab_size, self.config.mask_token_id, etc.
```python
p_uniform = max(math.exp(-float(clip_noise)), float(p_uniform))
log_B = float(gamma) * math.log(2.0) + math.log(p_uniform) - math.log(1.0 - p_uniform)
log_B = float(np.clip(log_B, -float(clip_noise), float(clip_noise)))
self.log_B = float(log_B)
self.log_gamma = float(math.log(float(gamma)))
```
Suggested change (note that the minus sign in `math.exp(-clip_noise)` must be kept when removing the casts):

```suggestion
p_uniform = max(math.exp(-clip_noise), p_uniform)
log_B = gamma * math.log(2.0) + math.log(p_uniform) - math.log(1.0 - p_uniform)
log_B = float(np.clip(log_B, -clip_noise, clip_noise))
self.log_B = log_B
self.log_gamma = math.log(gamma)
```
Can we remove all the int(...) and float(...) casts here and elsewhere?
```python
class HybridTokenDiffusionScheduler(SchedulerMixin, ConfigMixin):
    """
    Hybrid-transition discrete token diffusion scheduler.
```
Can the __init__ arguments (such as p_uniform, clip_noise, gamma, etc.) be documented in the docstring here, including what they mean and what values might be reasonable for them?
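A hedged sketch of how the docstring could look, in the usual diffusers `Args:` style; the descriptions paraphrase this review thread and should be checked against the actual scheduler semantics:

```python
# Sketch of a documented scheduler docstring; parameter descriptions are
# tentative and based on the review discussion, not verified semantics.
class HybridTokenDiffusionScheduler:
    """
    Hybrid-transition discrete token diffusion scheduler.

    Args:
        p_uniform (`float`, defaults to 0.0):
            Maximum probability that a token transitions to another token
            uniformly at random instead of to the mask token. `0.0` recovers
            a purely absorbing (mask-only) forward process.
        clip_noise (`float`, defaults to 20.0):
            Clamp applied to log-domain noise quantities for numerical
            stability.
        gamma (`float`):
            Shape parameter of the hybrid transition schedule (enters the
            `log_B` computation).
    """
```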
```python
p_uniform: float = 0.0,
clip_noise: float = 20.0,
```
The default parameters here set the effective uniform-transition probability to zero: p_uniform corresponds to the maximum probability that tokens will transition to another token uniformly at random instead of to the mask token. With the default p_uniform = 0.0, doesn't this scheduler reduce to the purely absorbing behavior of TokenDiffusionScheduler?
```python
elif noise_type == "cosine":
    return 1.0 - (1.0 - eps) * torch.cos(t * math.pi / 2.0)
```
The alpha schedules here are different from the alpha schedules defined in TokenDiffusionScheduler, is this intended? See also #12911 (comment).
```python
def _compute_move_chance(self, t: torch.Tensor) -> torch.Tensor:
    """
    Compute the probability that a token has been masked (move chance) at continuous time *t*.
```
If I understand correctly, the move chance is given by 1 - alpha_t. Documenting the relationship between move_chance and alpha_t here would be useful, since it would make it easier to compare this scheduler with similar schedulers such as TokenDiffusionScheduler.
```python
    `torch.Tensor`: Move chance at each timestep value, same shape as *t*.
    """
    noise_type = self.config.noise_type
    eps = 1e-3
```
I think eps should be configurable via __init__, following TokenDiffusionScheduler.
```python
# Compute move chances at t and s = t - dt
# ------------------------------------------------------------------
move_chance_t = self._compute_move_chance(t).to(dtype=torch.float64)
move_chance_s = self._compute_move_chance(t - dt).to(dtype=torch.float64)
```
I think getting s from the timestep schedule would be better in case we want to support non-linspace timestep schedules.
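Concretely, the loop could draw consecutive (t, s) pairs from the precomputed schedule rather than computing s = t - dt (a sketch with illustrative values; any monotone schedule works, not just a linspace):

```python
# Sketch: s comes from the timestep schedule itself, so non-uniform
# schedules (e.g. cosine-spaced) work without changing the step logic.
num_inference_steps = 4
timesteps = [1.0, 0.7, 0.35, 0.1, 0.0]  # illustrative non-linspace schedule

pairs = []
for i in range(num_inference_steps):
    t, s = timesteps[i], timesteps[i + 1]
    pairs.append((t, s))
```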
```python
# Subs parameterization: mask token gets -inf, then log_softmax normalizes.
# For unmasked positions, the distribution is forced to be the identity.
logits[..., mask_token_id] = -1e9
logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
```
My understanding is that this computes a log softmax, would using something like torch.special.log_softmax be better here? Also I think it might be more clear to rename this to something like log_probs.
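For illustration, `torch.log_softmax` matches the manual logsumexp normalization, and the clearer name makes the intent explicit (tensor values here are illustrative):

```python
import torch

# Sketch: torch.log_softmax is equivalent to subtracting logsumexp,
# and `log_probs` names the result more clearly than reusing `logits`.
logits = torch.randn(2, 5)
mask_token_id = 4
logits[..., mask_token_id] = -1e9  # exclude the mask token, as before

log_probs = torch.log_softmax(logits, dim=-1)
manual = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
```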
```python
gumbel_noise = -(torch.rand_like(q_xs, generator=generator) + 1e-10).log()
gumbel_noise = (1e-10 + gumbel_noise).clamp(min=1e-30)
x_block = (q_xs / gumbel_noise).argmax(dim=-1)
```
Could we reuse the _gumbel_argmax function defined in scheduling_token_diffusion.py here?
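A shared helper might look like this (a sketch only; the actual `_gumbel_argmax` in scheduling_token_diffusion.py may have a different signature). It uses the standard log-space Gumbel-max formulation, which samples from the same categorical distribution as the exponential-ratio form `(q_xs / E).argmax` quoted above:

```python
import torch

# Sketch of a Gumbel-argmax helper: argmax(log p + Gumbel noise) is a
# sample from Categorical(p).
def _gumbel_argmax(log_probs: torch.Tensor, generator=None) -> torch.Tensor:
    uniform = torch.rand(log_probs.shape, device=log_probs.device, generator=generator)
    gumbel = -torch.log(-torch.log(uniform.clamp_min(1e-10)).clamp_min(1e-10))
    return (log_probs + gumbel).argmax(dim=-1)

log_probs = torch.log_softmax(torch.randn(2, 4, 8), dim=-1)
tokens = _gumbel_argmax(log_probs)
```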
```python
def step(
    self,
    draft_tokens: torch.LongTensor,
    target_logits: torch.Tensor,
```
Would it be possible to refactor the DFlash step method such that it follows the standard step interface?
See diffusers/src/diffusers/schedulers/scheduling_block_refinement.py, lines 166 to 170 (e365d74).
Pipelines:
- TokenDiffusion: add usage example, remove unused attention_mask, add sample_prior to scheduler, inline enforce_fixed_masks
- HybridTokenDiffusion: consolidate into thin wrapper over TokenDiffusion
- SDAR: rename cur_x to block_x, remove int() casts, remove unused vars
- DFlash: inline _get_target_layer_ids, remove eval() calls, remove from_pretrained override, simplify model support
- BD3LM: refactor block logic parallel with LLaDA2

Schedulers:
- TokenDiffusion: pre-compute alpha schedule, cleaner if/elif in step, add sample_prior method
- HybridTokenDiffusion: remove redundant self.xxx assignments, document params, remove int/float casts
- BD3LM: make eps configurable, document move_chance vs alpha_t, use log_softmax, get s from timestep schedule, reuse _gumbel_argmax
- DFlash: refactor step to standard model_output interface
What does this PR do?
Adds experimental support for discrete token diffusion methods and pipelines.
Moved LLaDA 2 to its own PR: #13226
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.